import logging

import torch
from PIL import Image

from .base import BaseModel
from ..smp import *


class Chameleon(BaseModel):
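    """Wrapper for Meta's Chameleon (https://huggingface.co/facebook/chameleon-7b).

    Exposes the model through the BaseModel interface via the
    ChameleonProcessor / ChameleonForConditionalGeneration classes
    from HuggingFace transformers.
    """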

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path="facebook/chameleon-7b", **kwargs):
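        # Extra kwargs are accepted for interface compatibility but unused here.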
        try:
            from transformers import (
                ChameleonProcessor,
                ChameleonForConditionalGeneration,
            )
        except Exception as e:
            logging.critical("Please install the latest version of transformers to use Chameleon.")
            raise e

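        # Load the multimodal processor and the model weights in bfloat16.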
        processor = ChameleonProcessor.from_pretrained(model_path)
        model = ChameleonForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.bfloat16
        )

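        # Single-GPU inference: move the model to CUDA and disable dropout via eval().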
        self.model = model.cuda().eval()
        self.processor = processor

    def generate_inner(self, message, dataset=None):
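        # `message` is a list of dicts {"type": "text" | "image", "value": ...},
        # where "value" is either the text itself or a local image path.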
        content, images = "", []
        for x in message:
            if x["type"] == "text":
                content += x["value"]
            elif x["type"] == "image":
                content += "<image>\n"
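                # Force-decode to RGB so grayscale/RGBA inputs do not trip the processor.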
                images.append(Image.open(x["value"]).convert("RGB"))

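        # Tokenize the prompt and preprocess the images together, then move the
        # batch to the GPU in bfloat16 to match the model weights.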
        inputs = self.processor(
            text=[content], images=images, padding=True, return_tensors="pt"
        ).to(device="cuda", dtype=torch.bfloat16)
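        # Generate up to 512 new tokens; the decoding strategy falls back to the
        # model's generation_config.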
        generate_ids = self.model.generate(**inputs, max_new_tokens=512)
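        # `generate` returns prompt + completion, so slice off the prompt tokens
        # before decoding.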
        input_token_len = inputs.input_ids.shape[1]
        text = self.processor.batch_decode(
            generate_ids[:, input_token_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return text
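
# Minimal usage sketch (hypothetical image path; assumes a CUDA GPU and access
# to the `facebook/chameleon-7b` weights on the HuggingFace Hub):
#
#     model = Chameleon()
#     message = [
#         dict(type="image", value="demo.jpg"),
#         dict(type="text", value="Describe this image."),
#     ]
#     print(model.generate_inner(message))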
